{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the packages the attack needs: textattack, tensorflow, and tensorflow_hub.\n",
    "# %pip (instead of !pip) installs into the environment of the running kernel.\n",
    "%pip install textattack tensorflow tensorflow_hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2023-12-24 03:45:05.172891: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2023-12-24 03:45:05.172945: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2023-12-24 03:45:05.173987: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2023-12-24 03:45:05.180461: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-12-24 03:45:06.000286: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import promptbench as pb\n",
    "from promptbench.models import LLMModel\n",
    "from promptbench.prompt_attack import Attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['textbugger', 'deepwordbug', 'textfooler', 'bertattack', 'checklist', 'stresstest', 'semantic']\n"
     ]
    }
   ],
   "source": [
    "# Create the model that will be attacked.\n",
    "model_t5 = LLMModel(model=\"google/flan-t5-large\")\n",
    "\n",
    "# Load the \"sst2\" dataset and keep only its first 10 entries (a small subset for this demo).\n",
    "dataset = pb.DatasetLoader.load_dataset(\"sst2\")[:10]\n",
    "\n",
    "# Prompt template; {content} is the placeholder that stands for the input text.\n",
    "prompt = \"As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \\nQuestion: {content}\\nAnswer:\"\n",
    "\n",
    "# define the projection function required by the output process\n",
    "def proj_func(pred):\n",
    "    \"\"\"Map a predicted label string to its integer class id.\n",
    "\n",
    "    Returns 1 for 'positive', 0 for 'negative', and -1 for any other value.\n",
    "    \"\"\"\n",
    "    label_ids = {\"positive\": 1, \"negative\": 0}\n",
    "    if pred in label_ids:\n",
    "        return label_ids[pred]\n",
    "    return -1\n",
    "\n",
    "# define the evaluation function required by the attack\n",
    "def eval_func(prompt, dataset, model):\n",
    "    \"\"\"Evaluate `prompt` with `model` over `dataset` and return the accuracy.\n",
    "\n",
    "    Each entry is formatted into the prompt with pb.InputProcess.basic_format,\n",
    "    sent to the model, and the raw output is post-processed by\n",
    "    pb.OutputProcess.cls using proj_func. The collected predictions and the\n",
    "    entries' 'label' values are passed to pb.Eval.compute_cls_accuracy, whose\n",
    "    result is returned.\n",
    "    \"\"\"\n",
    "    predictions, references = [], []\n",
    "    for sample in dataset:\n",
    "        response = model(pb.InputProcess.basic_format(prompt, sample))\n",
    "        predictions.append(pb.OutputProcess.cls(response, proj_func))\n",
    "        references.append(sample[\"label\"])\n",
    "\n",
    "    return pb.Eval.compute_cls_accuracy(predictions, references)\n",
    "    \n",
    "# define the unmodifiable words in the prompt\n",
    "# for example, the labels \"positive\" and \"negative\" are unmodifiable, and \"content\" is unmodifiable because it is the placeholder\n",
    "# if your labels are enclosed in single quotes in the prompt, include the trailing ' in the unmodifiable word (due to one feature of textattack)\n",
    "unmodifiable_words = [\"positive'\", \"negative'\", \"content\"]\n",
    "\n",
    "# print all supported attacks\n",
    "print(Attack.attack_list())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "These words (if they appear in the prompt) are not allowed to be attacked:\n",
      "[\"positive'\", \"negative'\", 'content']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------\n",
      "Current prompt is:  As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \n",
      "Question: {content}\n",
      "Answer:\n",
      "Current accuracy is:  1.0\n",
      "--------------------------------------------------\n",
      "\n",
      "--------------------------------------------------\n",
      "Modifiable words:  ['As', 'a', 'sentiment', 'classifier', 'determine', 'whether', 'the', 'following', 'text', 'is', 'or', 'Please', 'classify', 'Question', 'Answer']\n",
      "--------------------------------------------------\n",
      "\n",
      "--------------------------------------------------\n",
      "Current prompt is:  As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \n",
      "Question: {content}\n",
      "Answer  and false is not true :\n",
      "Current accuracy is:  1.0\n",
      "--------------------------------------------------\n",
      "\n",
      "--------------------------------------------------\n",
      "Current prompt is:  As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \n",
      "Question: {content}\n",
      "Answer  and true is true  and true is true  and true is true  and true is true  and true is true :\n",
      "Current accuracy is:  1.0\n",
      "--------------------------------------------------\n",
      "\n",
      "--------------------------------------------------\n",
      "Current prompt is:  As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \n",
      "Question: {content}\n",
      "Answer  and true is true :\n",
      "Current accuracy is:  1.0\n",
      "--------------------------------------------------\n",
      "\n",
      "{'original prompt': \"As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \\nQuestion: {content}\\nAnswer:\", 'original score': 1.0, 'attacked prompt': \"As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \\nQuestion: {content}\\nAnswer  and false is not true :\", 'attacked score': 1.0, 'PDR': 0.0}\n"
     ]
    }
   ],
   "source": [
    "# Build the attack from the model, the attack name, the dataset, the prompt,\n",
    "# the evaluation function, and the unmodifiable words.\n",
    "# verbose=True means that the attack will print the intermediate results.\n",
    "attack = Attack(\n",
    "    model_t5,\n",
    "    \"stresstest\",\n",
    "    dataset,\n",
    "    prompt,\n",
    "    eval_func,\n",
    "    unmodifiable_words,\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# Run the attack and print its result.\n",
    "print(attack.attack())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "promptbench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
